{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow IO Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qFdPvlXBOdUN"
      },
      "source": [
        "# End-to-end example for the BigQuery TensorFlow reader"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://www.tensorflow.org/io/tutorials/bigquery\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\">View on TensorFlow.org</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ja/io/tutorials/bigquery.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\">Run in Google Colab</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ja/io/tutorials/bigquery.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\">View source on GitHub</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ja/io/tutorials/bigquery.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\">Download notebook</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xHxb-dlhMIzW"
      },
      "source": [
        "## Overview\n",
        "\n",
        "This tutorial shows how to use the [BigQuery TensorFlow reader](https://github.com/tensorflow/io/tree/master/tensorflow_io/bigquery) to train a neural network with the Keras Sequential API."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WodUv8O1VKmr"
      },
      "source": [
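        "### Dataset\n",
        "\n",
        "This tutorial uses the [United States Census Income Dataset](https://archive.ics.uci.edu/ml/datasets/census+income) provided by the [UC Irvine Machine Learning Repository](https://archive.ics.uci.edu/ml/index.php). The dataset contains records about people from the 1994 Census database. Each record includes the person's age, education, marital status, and occupation, plus whether they earn more than USD 50,000 a year."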
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MUXex9ctTuDB"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4YsfgDMZW5g6"
      },
      "source": [
        "Set up your GCP project.\n",
        "\n",
        "**The following steps are required regardless of your notebook environment.**\n",
        "\n",
        "1. [Select or create a GCP project.](https://console.cloud.google.com/cloud-resource-manager)\n",
        "2. [Make sure that billing is enabled for your project.](https://cloud.google.com/billing/docs/how-to/modify-project)\n",
        "3. [Enable the BigQuery Storage API.](https://cloud.google.com/bigquery/docs/reference/storage/#enabling_the_api)\n",
        "4. Enter your project ID in the `PROJECT_ID` cell below and run it. This makes the Cloud SDK use the right project for all the commands in this notebook.\n",
        "\n",
        "Note: Jupyter runs lines prefixed with `!` as shell commands, and it interpolates Python variables prefixed with `$` into these commands."
      ]
    },
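    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The next cell illustrates the note above by building, in plain Python, the command that `! gcloud config set project $PROJECT_ID` runs. It is a sketch, not how Jupyter implements the feature. `gcloud_set_project_argv` is a hypothetical helper and `my-sample-project` is a placeholder ID. Joining the arguments with `shlex.quote` produces a shell-safe string. Passing the list itself to `subprocess.run` would run the command without a shell, assuming the Cloud SDK is installed."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import shlex\n",
        "\n",
        "def gcloud_set_project_argv(project_id):\n",
        "  # Hypothetical helper: the command that `! gcloud config set project $PROJECT_ID` runs.\n",
        "  return ['gcloud', 'config', 'set', 'project', project_id]\n",
        "\n",
        "argv = gcloud_set_project_argv('my-sample-project')  # placeholder project ID\n",
        "# Shell-safe string form of the same command.\n",
        "command = ' '.join(shlex.quote(arg) for arg in argv)\n",
        "print(command)  # gcloud config set project my-sample-project\n",
        "# To run it (requires the Cloud SDK), pass the list directly so no shell is involved:\n",
        "# import subprocess; subprocess.run(argv, check=True)"
      ]
    },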
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "upgCc3gXybsA"
      },
      "source": [
        "Install the required packages, then restart the runtime."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QeBQuayhuvhg"
      },
      "outputs": [],
      "source": [
        "try:\n",
        "  # Use the Colab's preinstalled TensorFlow 2.x\n",
        "  %tensorflow_version 2.x \n",
        "except:\n",
        "  pass"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Tu01THzWcE-J"
      },
      "outputs": [],
      "source": [
        "!pip install fastavro\n",
        "!pip install tensorflow-io==0.9.0"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YUj0878jPyz7"
      },
      "outputs": [],
      "source": [
        "!pip install google-cloud-bigquery-storage"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yZmI7l_GykcW"
      },
      "source": [
        "Authenticate."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tUo3GJNxxZbc"
      },
      "outputs": [],
      "source": [
        "from google.colab import auth\n",
        "auth.authenticate_user()\n",
        "print('Authenticated')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "czF1KlC6y8fB"
      },
      "source": [
        "Set your project ID."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fbQR-bT_xgba"
      },
      "outputs": [],
      "source": [
        "PROJECT_ID = \"<YOUR PROJECT>\" #@param {type:\"string\"}\n",
        "! gcloud config set project $PROJECT_ID\n",
        "%env GCLOUD_PROJECT=$PROJECT_ID"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hOqN91M4y_9Y"
      },
      "source": [
        "Import the Python libraries and define constants."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "G3p2ICG80Z1t"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import tempfile\n",
        "\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import tensorflow as tf\n",
        "\n",
        "from google.cloud import bigquery\n",
        "from google.api_core.exceptions import GoogleAPIError\n",
        "\n",
        "LOCATION = 'us'\n",
        "\n",
        "# Storage directory\n",
        "DATA_DIR = os.path.join(tempfile.gettempdir(), 'census_data')\n",
        "\n",
        "# Download options.\n",
        "DATA_URL = 'https://storage.googleapis.com/cloud-samples-data/ml-engine/census/data'\n",
        "TRAINING_FILE = 'adult.data.csv'\n",
        "EVAL_FILE = 'adult.test.csv'\n",
        "TRAINING_URL = '%s/%s' % (DATA_URL, TRAINING_FILE)\n",
        "EVAL_URL = '%s/%s' % (DATA_URL, EVAL_FILE)\n",
        "\n",
        "DATASET_ID = 'census_dataset'\n",
        "TRAINING_TABLE_ID = 'census_training_table'\n",
        "EVAL_TABLE_ID = 'census_eval_table'\n",
        "\n",
        "CSV_SCHEMA = [\n",
        "      bigquery.SchemaField(\"age\", \"FLOAT64\"),\n",
        "      bigquery.SchemaField(\"workclass\", \"STRING\"),\n",
        "      bigquery.SchemaField(\"fnlwgt\", \"FLOAT64\"),\n",
        "      bigquery.SchemaField(\"education\", \"STRING\"),\n",
        "      bigquery.SchemaField(\"education_num\", \"FLOAT64\"),\n",
        "      bigquery.SchemaField(\"marital_status\", \"STRING\"),\n",
        "      bigquery.SchemaField(\"occupation\", \"STRING\"),\n",
        "      bigquery.SchemaField(\"relationship\", \"STRING\"),\n",
        "      bigquery.SchemaField(\"race\", \"STRING\"),\n",
        "      bigquery.SchemaField(\"gender\", \"STRING\"),\n",
        "      bigquery.SchemaField(\"capital_gain\", \"FLOAT64\"),\n",
        "      bigquery.SchemaField(\"capital_loss\", \"FLOAT64\"),\n",
        "      bigquery.SchemaField(\"hours_per_week\", \"FLOAT64\"),\n",
        "      bigquery.SchemaField(\"native_country\", \"STRING\"),\n",
        "      bigquery.SchemaField(\"income_bracket\", \"STRING\"),\n",
        "  ]\n",
        "\n",
        "UNUSED_COLUMNS = [\"fnlwgt\", \"education_num\"]"
      ]
    },
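    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Together, `CSV_SCHEMA` and `UNUSED_COLUMNS` decide which columns `read_bigquery` requests from the BigQuery reader later in this notebook, and with which types. The next cell reproduces that selection rule in plain Python, so you can check it without a GCP project. It is a sketch with two assumptions. The `Field` namedtuple is a hypothetical stand-in for `bigquery.SchemaField`. The `'double'` and `'string'` labels stand in for the `dtypes.double` and `dtypes.string` values that `read_bigquery` uses."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from collections import namedtuple\n",
        "\n",
        "# Hypothetical stand-in for bigquery.SchemaField so this runs without the BigQuery client.\n",
        "Field = namedtuple('Field', ['name', 'field_type'])\n",
        "\n",
        "# Same columns and types as CSV_SCHEMA above.\n",
        "schema = [\n",
        "    Field('age', 'FLOAT64'), Field('workclass', 'STRING'), Field('fnlwgt', 'FLOAT64'),\n",
        "    Field('education', 'STRING'), Field('education_num', 'FLOAT64'),\n",
        "    Field('marital_status', 'STRING'), Field('occupation', 'STRING'),\n",
        "    Field('relationship', 'STRING'), Field('race', 'STRING'), Field('gender', 'STRING'),\n",
        "    Field('capital_gain', 'FLOAT64'), Field('capital_loss', 'FLOAT64'),\n",
        "    Field('hours_per_week', 'FLOAT64'), Field('native_country', 'STRING'),\n",
        "    Field('income_bracket', 'STRING'),\n",
        "]\n",
        "unused_columns = ['fnlwgt', 'education_num']\n",
        "\n",
        "# Drop unused columns; FLOAT64 maps to double, everything else to string.\n",
        "selected = [(f.name, 'double' if f.field_type == 'FLOAT64' else 'string')\n",
        "            for f in schema if f.name not in unused_columns]\n",
        "\n",
        "print(len(selected))  # 13\n",
        "print(selected[0])    # ('age', 'double')"
      ]
    },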
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ooBfVnd-Xxd-"
      },
      "source": [
        "## Import census data into BigQuery"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0qt6wUD_2XFT"
      },
      "source": [
        "Define helper methods to load data into BigQuery."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7mMI7uW_2vP5"
      },
      "outputs": [],
      "source": [
        "from google.api_core.exceptions import Conflict\n",
        "\n",
        "def create_bigquery_dataset_if_necessary(dataset_id):\n",
        "  # Returns True if the dataset was created, False if it already existed.\n",
        "  client = bigquery.Client(project=PROJECT_ID)\n",
        "  # Construct a full Dataset object to send to the API.\n",
        "  dataset = bigquery.Dataset(bigquery.dataset.DatasetReference(PROJECT_ID, dataset_id))\n",
        "  dataset.location = LOCATION\n",
        "\n",
        "  try:\n",
        "    client.create_dataset(dataset)  # API request\n",
        "    return True\n",
        "  except Conflict:\n",
        "    # HTTP 409: the dataset already exists, so there is nothing to do.\n",
        "    return False\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Y3Cr-DEfwNhK"
      },
      "outputs": [],
      "source": [
        "def load_data_into_bigquery(url, table_id):\n",
        "  create_bigquery_dataset_if_necessary(DATASET_ID)\n",
        "  client = bigquery.Client(project=PROJECT_ID)\n",
        "  dataset_ref = client.dataset(DATASET_ID)\n",
        "  table_ref = dataset_ref.table(table_id)\n",
        "  job_config = bigquery.LoadJobConfig()\n",
        "  job_config.write_disposition = bigquery.WriteDisposition.WRITE_TRUNCATE\n",
        "  job_config.source_format = bigquery.SourceFormat.CSV\n",
        "  job_config.schema = CSV_SCHEMA\n",
        "\n",
        "  load_job = client.load_table_from_uri(\n",
        "      url, table_ref, job_config=job_config\n",
        "  )\n",
        "  print(\"Starting job {}\".format(load_job.job_id))\n",
        "\n",
        "  load_job.result()  # Waits for table load to complete.\n",
        "  print(\"Job finished.\")\n",
        "\n",
        "  destination_table = client.get_table(table_ref)\n",
        "  print(\"Loaded {} rows.\".format(destination_table.num_rows))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qSA0RIAZZEFZ"
      },
      "source": [
        "Load the census data into BigQuery."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wFZcK03-YDm4"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Starting job 2ceffef8-e6e4-44bb-9e86-3d97b0501187\n",
            "Job finished.\n",
            "Loaded 32561 rows.\n",
            "Starting job bf66f1b3-2506-408b-9009-c19f4ae9f58a\n",
            "Job finished.\n",
            "Loaded 16278 rows.\n"
          ]
        }
      ],
      "source": [
        "load_data_into_bigquery(TRAINING_URL, TRAINING_TABLE_ID)\n",
        "load_data_into_bigquery(EVAL_URL, EVAL_TABLE_ID)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KpUVI8IR2iXH"
      },
      "source": [
        "Confirm that the data was imported.\n",
        "\n",
        "TODO: Replace `<YOUR PROJECT>` with your PROJECT_ID.\n",
        "\n",
        "Note: `--use_bqstorage_api` fetches the data through the BigQuery Storage API and checks that you are authorized to use it. Make sure the API is enabled for your project: https://cloud.google.com/bigquery/docs/reference/storage/#enabling_the_api\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CVy3UkDgx2zi"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>age</th>\n",
              "      <th>workclass</th>\n",
              "      <th>fnlwgt</th>\n",
              "      <th>education</th>\n",
              "      <th>education_num</th>\n",
              "      <th>marital_status</th>\n",
              "      <th>occupation</th>\n",
              "      <th>relationship</th>\n",
              "      <th>race</th>\n",
              "      <th>gender</th>\n",
              "      <th>capital_gain</th>\n",
              "      <th>capital_loss</th>\n",
              "      <th>hours_per_week</th>\n",
              "      <th>native_country</th>\n",
              "      <th>income_bracket</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>39.0</td>\n",
              "      <td>Private</td>\n",
              "      <td>297847.0</td>\n",
              "      <td>9th</td>\n",
              "      <td>5.0</td>\n",
              "      <td>Married-civ-spouse</td>\n",
              "      <td>Other-service</td>\n",
              "      <td>Wife</td>\n",
              "      <td>Black</td>\n",
              "      <td>Female</td>\n",
              "      <td>3411.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>34.0</td>\n",
              "      <td>United-States</td>\n",
              "      <td>&lt;=50K</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>72.0</td>\n",
              "      <td>Private</td>\n",
              "      <td>74141.0</td>\n",
              "      <td>9th</td>\n",
              "      <td>5.0</td>\n",
              "      <td>Married-civ-spouse</td>\n",
              "      <td>Exec-managerial</td>\n",
              "      <td>Wife</td>\n",
              "      <td>Asian-Pac-Islander</td>\n",
              "      <td>Female</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>48.0</td>\n",
              "      <td>United-States</td>\n",
              "      <td>&gt;50K</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>45.0</td>\n",
              "      <td>Private</td>\n",
              "      <td>178215.0</td>\n",
              "      <td>9th</td>\n",
              "      <td>5.0</td>\n",
              "      <td>Married-civ-spouse</td>\n",
              "      <td>Machine-op-inspct</td>\n",
              "      <td>Wife</td>\n",
              "      <td>White</td>\n",
              "      <td>Female</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>40.0</td>\n",
              "      <td>United-States</td>\n",
              "      <td>&gt;50K</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>31.0</td>\n",
              "      <td>Private</td>\n",
              "      <td>86958.0</td>\n",
              "      <td>9th</td>\n",
              "      <td>5.0</td>\n",
              "      <td>Married-civ-spouse</td>\n",
              "      <td>Exec-managerial</td>\n",
              "      <td>Wife</td>\n",
              "      <td>White</td>\n",
              "      <td>Female</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>40.0</td>\n",
              "      <td>United-States</td>\n",
              "      <td>&lt;=50K</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>55.0</td>\n",
              "      <td>Private</td>\n",
              "      <td>176012.0</td>\n",
              "      <td>9th</td>\n",
              "      <td>5.0</td>\n",
              "      <td>Married-civ-spouse</td>\n",
              "      <td>Tech-support</td>\n",
              "      <td>Wife</td>\n",
              "      <td>White</td>\n",
              "      <td>Female</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>23.0</td>\n",
              "      <td>United-States</td>\n",
              "      <td>&lt;=50K</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "    age workclass    fnlwgt  ... hours_per_week  native_country income_bracket\n",
              "0  39.0   Private  297847.0  ...           34.0   United-States          <=50K\n",
              "1  72.0   Private   74141.0  ...           48.0   United-States           >50K\n",
              "2  45.0   Private  178215.0  ...           40.0   United-States           >50K\n",
              "3  31.0   Private   86958.0  ...           40.0   United-States          <=50K\n",
              "4  55.0   Private  176012.0  ...           23.0   United-States          <=50K\n",
              "\n",
              "[5 rows x 15 columns]"
            ]
          },
          "execution_count": 10,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "%%bigquery --use_bqstorage_api\n",
        "SELECT * FROM `<YOUR PROJECT>.census_dataset.census_training_table` LIMIT 5"
      ]
    },
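    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Instead of editing the SQL by hand, you can build the fully qualified table name in Python from the constants defined earlier. In this sketch, `preview_query` is a hypothetical helper and `my-sample-project` is a placeholder. In this notebook you would call it as `preview_query(PROJECT_ID, DATASET_ID, TRAINING_TABLE_ID)`. The commented-out last line shows how the `bigquery` client imported earlier could run the query, assuming you are authenticated and the table exists."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def preview_query(project_id, dataset_id, table_id, limit=5):\n",
        "  # Hypothetical helper: build the preview query above from its parts.\n",
        "  # Backticks around the full path allow project IDs that contain hyphens.\n",
        "  return 'SELECT * FROM `{}.{}.{}` LIMIT {}'.format(project_id, dataset_id, table_id, int(limit))\n",
        "\n",
        "sql = preview_query('my-sample-project', 'census_dataset', 'census_training_table')\n",
        "print(sql)\n",
        "# In this notebook: preview_query(PROJECT_ID, DATASET_ID, TRAINING_TABLE_ID)\n",
        "# To run it with the BigQuery client (requires authentication and the loaded table):\n",
        "# bigquery.Client(project=PROJECT_ID).query(sql).to_dataframe()"
      ]
    },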
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tOu-pCksYTtE"
      },
      "source": [
        "## Load the census data into a TensorFlow Dataset using the BigQuery reader"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6_Gm8Mzh62yF"
      },
      "source": [
        "Read the census data from BigQuery and transform it into a TensorFlow Dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NgPd9w5m06In"
      },
      "outputs": [],
      "source": [
        "from tensorflow.python.framework import dtypes\n",
        "from tensorflow_io.bigquery import BigQueryClient\n",
        "from tensorflow_io.bigquery import BigQueryReadSession\n",
        "\n",
        "def transform_row(row_dict):\n",
        "  # Trim leading/trailing whitespace from all string tensors.\n",
        "  trimmed_dict = {\n",
        "      column: (tf.strings.strip(tensor) if tensor.dtype == tf.string else tensor)\n",
        "      for (column, tensor) in row_dict.items()\n",
        "  }\n",
        "  # Separate the label column from the features.\n",
        "  income_bracket = trimmed_dict.pop('income_bracket')\n",
        "  # Encode the label as 1.0 for '>50K' and 0.0 otherwise.\n",
        "  income_bracket_float = tf.cond(tf.equal(income_bracket, '>50K'),\n",
        "                                 lambda: tf.constant(1.0),\n",
        "                                 lambda: tf.constant(0.0))\n",
        "  return (trimmed_dict, income_bracket_float)\n",
        "\n",
        "def read_bigquery(table_name):\n",
        "  tensorflow_io_bigquery_client = BigQueryClient()\n",
        "  # Request only the used columns: FLOAT64 columns become double tensors, the rest strings.\n",
        "  read_session = tensorflow_io_bigquery_client.read_session(\n",
        "      \"projects/\" + PROJECT_ID,\n",
        "      PROJECT_ID, table_name, DATASET_ID,\n",
        "      list(field.name for field in CSV_SCHEMA\n",
        "           if field.name not in UNUSED_COLUMNS),\n",
        "      list(dtypes.double if field.field_type == 'FLOAT64'\n",
        "           else dtypes.string for field in CSV_SCHEMA\n",
        "           if field.name not in UNUSED_COLUMNS),\n",
        "      requested_streams=2)\n",
        "\n",
        "  # Read the streams in parallel, then turn each row into a (features, label) pair.\n",
        "  dataset = read_session.parallel_read_rows()\n",
        "  transformed_ds = dataset.map(transform_row)\n",
        "  return transformed_ds\n"
      ]
    },
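    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The next cell makes the row transformation concrete. It applies the same three steps as `transform_row` to a plain Python dictionary: trim strings, pop the label, and encode the label as 1.0/0.0. `transform_row_py` is a hypothetical helper for illustration only. The leading spaces in the sample values are an assumption about how raw CSV fields may look; the trimming step would remove them."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def transform_row_py(row):\n",
        "  # Hypothetical plain-Python analogue of transform_row, for illustration only.\n",
        "  # 1. Strip whitespace from string values; leave numbers untouched.\n",
        "  trimmed = {k: (v.strip() if isinstance(v, str) else v) for k, v in row.items()}\n",
        "  # 2. Separate the label from the features.\n",
        "  label = trimmed.pop('income_bracket')\n",
        "  # 3. Encode the label as 1.0 for '>50K' and 0.0 otherwise.\n",
        "  return trimmed, (1.0 if label == '>50K' else 0.0)\n",
        "\n",
        "features, label = transform_row_py(\n",
        "    {'age': 39.0, 'workclass': ' Private', 'income_bracket': ' <=50K'})\n",
        "print(features, label)  # {'age': 39.0, 'workclass': 'Private'} 0.0"
      ]
    },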
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4_NlkxZt1rwR"
      },
      "outputs": [],
      "source": [
        "BATCH_SIZE = 32\n",
        "\n",
        "training_ds = read_bigquery(TRAINING_TABLE_ID).shuffle(10000).batch(BATCH_SIZE)\n",
        "eval_ds = read_bigquery(EVAL_TABLE_ID).batch(BATCH_SIZE)"
      ]
    },
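    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "By default, `batch` keeps the final partial batch (`drop_remainder=False`). The number of steps per epoch is therefore the row count divided by the batch size, rounded up. The arithmetic in the next cell uses the row counts printed by the load jobs above. It agrees with the `1018/1018` and `509/509` progress bars in the training and evaluation output. Note also that `shuffle(10000)` uses a buffer smaller than the 32,561 training rows, so it is not a uniform shuffle of the whole table."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "batch_size = 32     # BATCH_SIZE above\n",
        "train_rows = 32561  # rows loaded into the training table\n",
        "eval_rows = 16278   # rows loaded into the evaluation table\n",
        "\n",
        "# batch() keeps the final partial batch, so round up.\n",
        "train_steps = math.ceil(train_rows / batch_size)\n",
        "eval_steps = math.ceil(eval_rows / batch_size)\n",
        "print(train_steps, eval_steps)  # 1018 509"
      ]
    },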
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S_iXxNdSYsDO"
      },
      "source": [
        "## Define feature columns"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UClHDwcyhFky"
      },
      "outputs": [],
      "source": [
        "def get_categorical_feature_values(column):\n",
        "  # Return the distinct, whitespace-trimmed values of a column in the training table.\n",
        "  query = 'SELECT DISTINCT TRIM({}) FROM `{}`.{}.{}'.format(column, PROJECT_ID, DATASET_ID, TRAINING_TABLE_ID)\n",
        "  client = bigquery.Client(project=PROJECT_ID)\n",
        "  query_job = client.query(query)\n",
        "  result = query_job.to_dataframe()\n",
        "  return result.values[:,0]"
      ]
    },
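    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Each vocabulary returned by `get_categorical_feature_values` feeds a `categorical_column_with_vocabulary_list`. The feature-column cell below wraps that column in an `indicator_column`, which one-hot encodes it. The next cell mimics both steps on a few sample `workclass` values, using the hypothetical helpers `distinct_trimmed` and `one_hot`.\n",
        "\n",
        "- **Vocabulary order:** BigQuery does not guarantee the order of `SELECT DISTINCT` results without `ORDER BY`. The vocabulary order, and therefore each value's one-hot position, may differ between runs. The sketch simply keeps first-seen order.\n",
        "- **Unknown values:** values missing from the vocabulary map to all zeros, because the vocabulary column has no out-of-vocabulary buckets by default."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def distinct_trimmed(values):\n",
        "  # Hypothetical analogue of SELECT DISTINCT TRIM(col): trim, keep the first occurrence of each value.\n",
        "  seen = []\n",
        "  for v in values:\n",
        "    v = v.strip()\n",
        "    if v not in seen:\n",
        "      seen.append(v)\n",
        "  return seen\n",
        "\n",
        "def one_hot(value, vocabulary):\n",
        "  # Hypothetical analogue of an indicator column over a vocabulary list;\n",
        "  # an out-of-vocabulary value yields all zeros.\n",
        "  return [1.0 if value == term else 0.0 for term in vocabulary]\n",
        "\n",
        "vocab = distinct_trimmed([' Private', ' Self-emp-not-inc', ' Private', ' Local-gov'])\n",
        "print(vocab)                        # ['Private', 'Self-emp-not-inc', 'Local-gov']\n",
        "print(one_hot('Local-gov', vocab))  # [0.0, 0.0, 1.0]"
      ]
    },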
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "h9aAs1ZAtQr-"
      },
      "outputs": [],
      "source": [
        "from tensorflow import feature_column\n",
        "\n",
        "feature_columns = []\n",
        "\n",
        "# numeric cols\n",
        "for header in ['capital_gain', 'capital_loss', 'hours_per_week']:\n",
        "  feature_columns.append(feature_column.numeric_column(header))\n",
        "\n",
        "# categorical cols\n",
        "for header in ['workclass', 'marital_status', 'occupation', 'relationship',\n",
        "               'race', 'native_country', 'education']:\n",
        "  categorical_feature = feature_column.categorical_column_with_vocabulary_list(\n",
        "        header, get_categorical_feature_values(header))\n",
        "  categorical_feature_one_hot = feature_column.indicator_column(categorical_feature)\n",
        "  feature_columns.append(categorical_feature_one_hot)\n",
        "\n",
        "# bucketized cols\n",
        "age = feature_column.numeric_column('age')\n",
        "age_buckets = feature_column.bucketized_column(age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])\n",
        "feature_columns.append(age_buckets)\n",
        "\n",
        "feature_layer = tf.keras.layers.DenseFeatures(feature_columns)"
      ]
    },
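    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`bucketized_column` turns the numeric `age` into a one-hot bucket index. The 10 boundaries produce 11 buckets:\n",
        "\n",
        "- bucket 0 covers ages below 18;\n",
        "- each middle bucket covers a half-open range such as [18, 25);\n",
        "- the last bucket covers 65 and above.\n",
        "\n",
        "The next cell computes the same index with Python's `bisect` module. `age_bucket` is a hypothetical helper for illustration, not part of TensorFlow."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import bisect\n",
        "\n",
        "AGE_BOUNDARIES = [18, 25, 30, 35, 40, 45, 50, 55, 60, 65]\n",
        "\n",
        "def age_bucket(age_value):\n",
        "  # Hypothetical helper: index of the bucket that age_value falls into.\n",
        "  # Bucket 0 is (-inf, 18), bucket i is [boundaries[i-1], boundaries[i]), the last is [65, +inf).\n",
        "  return bisect.bisect_right(AGE_BOUNDARIES, age_value)\n",
        "\n",
        "print(len(AGE_BOUNDARIES) + 1)  # 11 buckets\n",
        "print([age_bucket(a) for a in [17, 18, 36, 56, 90]])  # [0, 1, 4, 8, 10]"
      ]
    },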
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HfO0bhXXY3GQ"
      },
      "source": [
        "## Build and train the model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GnKZKOQX7Qwx"
      },
      "source": [
        "Build the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Sm-UsB5_zvt0"
      },
      "outputs": [],
      "source": [
        "Dense = tf.keras.layers.Dense\n",
        "model = tf.keras.Sequential(\n",
        "  [\n",
        "      feature_layer,\n",
        "      Dense(100, activation=tf.nn.relu, kernel_initializer='uniform'),\n",
        "      Dense(75, activation=tf.nn.relu),\n",
        "      Dense(50, activation=tf.nn.relu),\n",
        "      Dense(25, activation=tf.nn.relu),\n",
        "      Dense(1, activation=tf.nn.sigmoid)\n",
        "  ])\n",
        "\n",
        "# Compile the Keras model. 'rmsprop' is the optimizer Keras uses when none is given.\n",
        "model.compile(\n",
        "    optimizer='rmsprop',\n",
        "    loss='binary_crossentropy',\n",
        "    metrics=['accuracy'])"
      ]
    },
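    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The model ends in a single sigmoid unit and is trained with `binary_crossentropy`, so each output is read as the probability that income exceeds 50K. The next cell works through that loss for one example in plain Python. A logit of 0 yields a probability of 0.5, and the loss for a positive label is then ln 2, about 0.693. This is an illustration only; Keras additionally clips predictions away from 0 and 1 and averages the loss over the batch."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "def sigmoid(z):\n",
        "  return 1.0 / (1.0 + math.exp(-z))\n",
        "\n",
        "def binary_crossentropy(y_true, y_pred):\n",
        "  # Loss for one example: label y_true in {0, 1}, predicted probability y_pred in (0, 1).\n",
        "  return -(y_true * math.log(y_pred) + (1 - y_true) * math.log(1 - y_pred))\n",
        "\n",
        "p = sigmoid(0.0)                  # a logit of 0 gives probability 0.5\n",
        "print(p)                          # 0.5\n",
        "print(binary_crossentropy(1, p))  # ln 2, about 0.693"
      ]
    },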
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f8bSDfQd7T1n"
      },
      "source": [
        "Train the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gPKrlFCN1y00"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "WARNING:tensorflow:Layer sequential is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because it's dtype defaults to floatx.\n",
            "\n",
            "If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.\n",
            "\n",
            "To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.\n",
            "\n",
            "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/feature_column/feature_column_v2.py:4276: IndicatorColumn._variable_shape (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n",
            "Instructions for updating:\n",
            "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
            "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/feature_column/feature_column_v2.py:4331: VocabularyListCategoricalColumn._num_buckets (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n",
            "Instructions for updating:\n",
            "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
            "Epoch 1/5\n",
            "1018/1018 [==============================] - 17s 17ms/step - loss: 0.5985 - accuracy: 0.8105\n",
            "Epoch 2/5\n",
            "1018/1018 [==============================] - 10s 10ms/step - loss: 0.3670 - accuracy: 0.8324\n",
            "Epoch 3/5\n",
            "1018/1018 [==============================] - 11s 10ms/step - loss: 0.3487 - accuracy: 0.8393\n",
            "Epoch 4/5\n",
            "1018/1018 [==============================] - 11s 10ms/step - loss: 0.3398 - accuracy: 0.8435\n",
            "Epoch 5/5\n",
            "1018/1018 [==============================] - 11s 11ms/step - loss: 0.3377 - accuracy: 0.8455\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "<tensorflow.python.keras.callbacks.History at 0x7f978f5b91d0>"
            ]
          },
          "execution_count": 17,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "model.fit(training_ds, epochs=5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UgKKliXRZG_G"
      },
      "source": [
        "## Evaluate the model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SgNd5DdU7TEW"
      },
      "source": [
        "Evaluate the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8eGHVkmI5LBT"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "509/509 [==============================] - 8s 15ms/step - loss: 0.3338 - accuracy: 0.8398\n",
            "Accuracy 0.8398452\n"
          ]
        }
      ],
      "source": [
        "loss, accuracy = model.evaluate(eval_ds)\n",
        "print(\"Accuracy\", accuracy)"
      ]
    },
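    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "With `loss='binary_crossentropy'` and a single sigmoid output, Keras computes the `'accuracy'` metric passed to `compile` as binary accuracy. A prediction counts as positive when its probability is greater than 0.5, and the metric is the fraction of predictions that match the label. The next cell sketches that computation in plain Python, using illustrative labels and probabilities rather than real model output."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def binary_accuracy(y_true, y_prob, threshold=0.5):\n",
        "  # Fraction of examples whose thresholded prediction (prob > threshold) matches the label.\n",
        "  hits = sum(int(p > threshold) == int(t) for t, p in zip(y_true, y_prob))\n",
        "  return hits / len(y_true)\n",
        "\n",
        "# Illustrative labels and probabilities, not model output.\n",
        "print(binary_accuracy([1, 0, 0, 1], [0.9, 0.2, 0.7, 0.4]))  # 0.5"
      ]
    },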
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qBIWoGMP7bj1"
      },
      "source": [
        "Get predictions for two hand-written sample records."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aMou1t1xngXP"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "array([[0.5541261],\n",
              "       [0.6209938]], dtype=float32)"
            ]
          },
          "execution_count": 19,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "sample_x = {\n",
        "    'age' : np.array([56, 36]), \n",
        "    'workclass': np.array(['Local-gov', 'Private']), \n",
        "    'education': np.array(['Bachelors', 'Bachelors']), \n",
        "    'marital_status': np.array(['Married-civ-spouse', 'Married-civ-spouse']), \n",
        "    'occupation': np.array(['Tech-support', 'Other-service']), \n",
        "    'relationship': np.array(['Husband', 'Husband']), \n",
        "    'race': np.array(['White', 'Black']), \n",
        "    'gender': np.array(['Male', 'Male']), \n",
        "    'capital_gain': np.array([0, 7298]), \n",
        "    'capital_loss': np.array([0, 0]), \n",
        "    'hours_per_week': np.array([40, 36]), \n",
        "    'native_country': np.array(['United-States', 'United-States'])\n",
        "  }\n",
        "\n",
        "model.predict(sample_x)"
      ]
    },
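    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To read the prediction output as labels, apply the same 0.5 threshold. The next cell uses the two probabilities printed above; `to_income_bracket` is a hypothetical helper. Both values are only slightly above 0.5, so both samples are classified as `>50K` with low confidence. The exact numbers will also differ from one training run to the next."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def to_income_bracket(probabilities, threshold=0.5):\n",
        "  # Hypothetical helper: map sigmoid outputs to the dataset's label strings.\n",
        "  return ['>50K' if p > threshold else '<=50K' for p in probabilities]\n",
        "\n",
        "# The two probabilities printed by model.predict above.\n",
        "# In this notebook: to_income_bracket(model.predict(sample_x)[:, 0])\n",
        "print(to_income_bracket([0.5541261, 0.6209938]))  # ['>50K', '>50K']"
      ]
    },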
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UhNtHfuxCGVy"
      },
      "source": [
        "## Resources\n",
        "\n",
        "- [Google Cloud BigQuery overview](https://github.com/tensorflow/io/blob/master/tensorflow_io/bigquery/README.md)\n",
        "- [Training and prediction with Keras on AI Platform](https://colab.sandbox.google.com/github/GoogleCloudPlatform/cloudml-samples/blob/master/notebooks/tensorflow/getting-started-keras.ipynb)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "Tce3stUlHN0L"
      ],
      "name": "bigquery.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
